#############################################################################
# CODE FOR GROUPED FIXED EFFECTS (GFE) ESTIMATION
# - Slopes on ddr are grouped (G groups); intercepts are free per county
# - Expects data that has been demeaned within each county
#
#############################################################################

#############################################################################
# Wrapper function to call for estimation
#############################################################################

# Wrapper: build starting slopes, run group_FE(), and tabulate marginal
# effects per county-year.
#
# Args:
#   data       - estimation data passed to group_FE().
#   data_ori   - data passed to group_mg() to build the county-year table.
#   G          - number of slope groups.
#   beta0      - data.table with a `slope` column of preliminary slopes. It is
#                only read; no column is added to it.
#   gamma0     - named starting coefficients on the control columns.
#   se_beta    - sd of the random perturbations of the starting slopes in
#                group_FE().
#   se_gamma   - sd(s) of the random perturbations of gamma0 in group_FE().
#   niter, tol - iteration cap and tolerance for each GFE run.
#   Ns         - number of starting values (the one built here + Ns - 1
#                random ones).
#   sseed      - RNG seed.
#   save_name  - tag of the checkpoint file
#                Data_new/GFE/Gfe_res1_<save_name>.RData.
# Returns list(res = group_FE() output,
#              dc = county-year marginal effects, groups renumbered by
#                   decreasing beta,
#              time_dif = elapsed time as a difftime).
gfe_m1_wrap <- function(data, data_ori, G, beta0, gamma0, se_beta, se_gamma,
                        niter, tol, Ns, sseed = 1, save_name) {

  t0 <- Sys.time()

  # Starting values ----
  # Split the preliminary slopes into G quantile bins and start each group
  # at its bin's mean slope. Working on a plain vector avoids adding a
  # column to the caller's beta0 by reference.
  slopes <- beta0[["slope"]]
  qq <- quantile(slopes, probs = seq(0, 1, length.out = G + 1))
  # Bin of each slope = largest g with slope >= qq[g] (qq is non-decreasing).
  bins <- findInterval(slopes, qq[seq_len(G)])
  beta0v <- vapply(seq_len(G), function(g) mean(slopes[bins == g]),
                   numeric(1))

  # Iterate over starting values ----
  # se_beta is passed through as given; it is the sd of the perturbations
  # added to beta0v for the random starts.
  save_pos <- paste0("Data_new/GFE/Gfe_res1_", save_name, ".RData")
  res <- group_FE(data = data, G = G, gamma0 = gamma0, beta0 = beta0v,
                  se_gamma = se_gamma, se_beta = se_beta, Ns = Ns,
                  niter = niter, tol = tol, sseed = sseed,
                  save_name = save_pos, save_every = 200)

  # Table of marginal effects, groups ordered by decreasing beta ----
  dc <- group_mg(G = G, data = data_ori, res = res)
  dc <- data_groups_order(dc)

  time_dif <- Sys.time() - t0
  list(res = res, dc = dc, time_dif = time_dif)
}


#############################################################################
# Grouped FE for a given starting value and demeaned data
#############################################################################

# Grouped FE estimation from a single starting value.
#
# Alternates two steps until the absolute change in the objective (sum of
# squared residuals) is <= tol, or the iteration cap is reached:
#   1. assignment: each county (fips) joins the group g whose slope
#      beta[g] minimises the sum over its rows of squared residuals
#      (ties go to the lower g);
#   2. update: feols regresses yield, with no intercept, on ddr interacted
#      with the group dummies plus the columns named in gamma0.
#
# Args:
#   data   - data.table with columns fips, yield, ddr and names(gamma0).
#   G      - number of groups (G = 1 is allowed).
#   gamma0 - named numeric starting coefficients on the control columns.
#   beta0  - numeric vector of G starting group slopes.
#   niter  - iteration cap. The loop runs while j <= niter, so up to
#            niter + 1 passes can run.
#   tol    - tolerance on the absolute change in the objective.
# Returns a list:
#   iter  - number of passes run;
#   obj   - objective after each pass;
#   beta  - group slopes, named ddr_g1..ddr_gG;
#   gamma - control coefficients, with NAs set to 0;
#   group - data.table(fips, group), one row per county, sorted by fips.
group_FE_s <- function(data, G, gamma0, beta0, niter = 10000, tol = 1e-7) {

  dtm <- copy(data)

  y <- dtm[["yield"]]
  x_ddr <- dtm[["ddr"]]
  Z <- as.matrix(dtm[, .SD, .SDcols = names(gamma0)])

  # Group-specific slope regressors and the (loop-invariant) formula:
  # yield ~ ddr_g1 + ... + ddr_gG + <controls> - 1
  beta_names <- paste0("ddr_g", seq_len(G))
  form <- as.formula(paste("yield ~ ",
                           paste(beta_names, collapse = "+"), "+",
                           paste(names(gamma0), collapse = "+"), "-1"))

  obj_vec <- c()
  j <- 0
  obj <- sum(y^2)
  dch <- 1e8
  gamma_now <- gamma0
  beta_now <- beta0

  while (j <= niter && dch > tol) {

    ## Step 1: assignment step
    # The controls' fitted part does not depend on g; compute it once per
    # pass. The residual is still formed as (y - b*x) - Z gamma, which keeps
    # the original order of arithmetic.
    z_gamma <- as.vector(Z %*% gamma_now)
    resg <- data.table(fips = dtm[["fips"]])
    for (g in seq_len(G)) {
      resg[, (paste0("res", g)) := (y - beta_now[g] * x_ddr - z_gamma)^2]
    }

    # Sum squared residuals over each county's rows. keyby sorts counties
    # by fips so the returned assignment table has a deterministic order.
    resg <- resg[, lapply(.SD, sum), keyby = "fips"]

    # Pick the group with the smallest sum; the strict `<` keeps the lower
    # g on ties. NA sums (e.g. from an NA slope) never win.
    resg[, group := 1L]
    resg[, res_min := res1]
    for (g in seq_len(G)[-1]) {
      resg[, res_now := get(paste0("res", g))]
      resg[res_now < res_min, `:=`(group = g, res_min = res_now)]
    }

    # Attach the chosen group to every row of the data.
    dt_now <- merge(dtm, resg[, .(fips, group)], by = "fips", all.x = TRUE)

    ## Step 2: update step
    # ddr interacted with each group dummy.
    for (g in seq_len(G)) {
      dt_now[, (beta_names[g]) := (group == g) * ddr]
    }
    reg_now <- feols(fml = form, data = dt_now)

    # Regressors dropped by feols (e.g. the extra state-year FE lost per
    # period) have no coefficient, so indexing by name yields NA. Controls
    # are then reset to 0; slopes keep their NA.
    gamma_now <- reg_now$coefficients[names(gamma0)]
    gamma_now[is.na(gamma_now)] <- 0
    gamma_now <- setNames(gamma_now, names(gamma0))
    beta_now <- reg_now$coefficients[beta_names]

    # Objective value and convergence check.
    obj_now <- reg_now$ssr
    obj_vec <- c(obj_vec, obj_now)
    dch <- abs(obj - obj_now)
    obj <- obj_now
    j <- j + 1
  }

  list(iter = j, obj = obj_vec, beta = beta_now, gamma = gamma_now,
       group = resg[, .SD, .SDcols = c("fips", "group")])
}

#############################################################################
# Grouped FE from the given starting value plus Ns-1 randomly perturbed ones
#############################################################################

# Grouped FE across several starting values.
#
# Runs group_FE_s() once from (beta0, gamma0). It then runs Ns - 1 more times
# from random starts:
#   beta0 + rnorm(G, 0, se_beta)
#   gamma0 + rmvnorm(0, diag(se_gamma^2))
# se_gamma may be a scalar or one sd per element of gamma0.
#
# Per-run results are accumulated in R_obj, R_iter, R_gamma, R_beta and
# R_groups. These are assigned with <<- (into the function's enclosing
# environment, the global one when this file is sourced at top level). This
# way the workspace checkpoint written by save.image(save_name) every
# save_every random starts contains the partial results.
#
# Returns a list whose runs are ordered from best (smallest final
# objective) to worst. Runs with an NA objective are ranked last.
#   conv        - 1 if the best run stopped with iter < niter, else 0;
#   obj, iter   - final objective and number of passes of each run;
#   gamma, beta - one row per run (always matrices);
#   groups      - data.table with fips (sorted) and V1..VNs, the group of
#                 each county in each run (V1 = best run).
group_FE <- function(data, G, gamma0, beta0, se_gamma, se_beta, Ns, niter,
                     tol, sseed = 1, save_name, save_every = 500) {

  dtm <- copy(data)
  set.seed(sseed)

  # Sorted counties define the row order of every group-assignment column.
  fips_sorted <- sort(unique(dtm$fips))
  # Reorder one run's (fips, group) table to that row order. This way the
  # columns line up by county whatever order the run returns.
  align_groups <- function(grp) grp$group[match(fips_sorted, grp$fips)]

  # Accumulators (first row of R_gamma / R_beta is an NA placeholder).
  R_obj <<- c()
  R_iter <<- c()
  R_gamma <<- matrix(NA, nrow = 1, ncol = length(gamma0))
  R_beta <<- matrix(NA, nrow = 1, ncol = length(beta0))
  R_groups <<- matrix(fips_sorted, ncol = 1)

  # Run from the starting value given.
  res <- group_FE_s(data = dtm, G = G, gamma0 = gamma0, beta0 = beta0,
                    niter = niter, tol = tol)
  R_obj <<- c(R_obj, res$obj[length(res$obj)])
  R_iter <<- c(R_iter, res$iter)
  R_gamma <<- rbind(R_gamma, res$gamma)
  R_beta <<- rbind(R_beta, res$beta)
  R_groups <<- cbind(R_groups, align_groups(res$group))

  # Covariance of the gamma perturbations. nrow= makes a scalar se_gamma
  # mean se_gamma^2 * I; plain diag(scalar) would build an identity matrix
  # of size se_gamma^2.
  sigma_gamma <- diag(se_gamma^2, nrow = length(gamma0))

  # Random starts (none when Ns = 1).
  for (k in seq_len(Ns - 1)) {

    print(k)

    beta_now <- beta0 + rnorm(n = G, mean = 0, sd = se_beta)
    gamma_now <- as.numeric(gamma0 + rmvnorm(n = 1,
                                             mean = rep(0, length(gamma0)),
                                             sigma = sigma_gamma))
    names(gamma_now) <- names(gamma0)

    gfe <- group_FE_s(data = dtm, G = G, gamma0 = gamma_now,
                      beta0 = beta_now, niter = niter, tol = tol)

    R_obj <<- c(R_obj, gfe$obj[length(gfe$obj)])
    R_iter <<- c(R_iter, gfe$iter)
    R_gamma <<- rbind(R_gamma, gfe$gamma)
    R_beta <<- rbind(R_beta, gfe$beta)
    R_groups <<- cbind(R_groups, align_groups(gfe$group))

    if (k %% save_every == 0) {
      save.image(save_name)
    }
  }

  # Drop the placeholder rows. drop = FALSE keeps matrices even with one
  # run, one group or one control.
  R_gamma <<- R_gamma[-1, , drop = FALSE]
  R_beta <<- R_beta[-1, , drop = FALSE]

  conv <- as.numeric(R_iter[which.min(R_obj)] < niter)

  # Sort runs from best to worst objective (stable; NA last).
  pos <- order(R_obj)
  group_cols <- R_groups[, -1, drop = FALSE][, pos, drop = FALSE]
  groups_dt <- data.table(cbind(R_groups[, 1], group_cols))
  setnames(groups_dt, c("fips", paste0("V", seq_len(Ns))))

  list(conv = conv,
       obj = R_obj[pos],
       gamma = R_gamma[pos, , drop = FALSE],
       beta = R_beta[pos, , drop = FALSE],
       groups = groups_dt,
       iter = R_iter[pos])
}

#############################################################################
# Function that returns a data set of counties with the coefficients
#############################################################################

# County-level table for the best-ranked run of group_FE().
#
# Args:
#   G   - number of groups.
#   res - output of group_FE(). Column V1 of res$groups and row 1 of
#         res$beta belong to the best run.
# Returns a data.table with columns:
#   fips, group;
#   beta_g1..beta_gG - each group's slope, the same on every row;
#   beta             - the county's own slope. A missing (NA) group slope
#                      counts as 0.
group_fips <- function(G, res) {

  county <- res$groups[, .(fips, group = V1)]

  for (g in seq_len(G)) {
    county[, (paste0("beta_g", g)) := res$beta[1, g]]
  }

  # Only the indicator of the county's own group is 1, so the sum picks
  # that group's slope.
  county[, beta := 0]
  for (g in seq_len(G)) {
    slope_g <- county[[paste0("beta_g", g)]]
    slope_g[is.na(slope_g)] <- 0
    county[, beta := beta + slope_g * (group == g)]
  }

  county
}

#############################################################################
# Function that returns a data county-year with marginal effects
#############################################################################

# County-year table of marginal effects from the best-ranked GFE run.
#
# Args:
#   G    - number of groups.
#   data - data.table with columns fips, year, ddr, corn_area.
#   res  - output of group_FE().
# Returns the rows of `data` (those four columns) left-joined by fips to
# group_fips(). mg (the marginal effect of ddr) equals the county's beta.
group_mg <- function(G, data, res) {
  county_coefs <- group_fips(G = G, res = res)
  panel <- data[, .SD, .SDcols = c("fips", "year", "ddr", "corn_area")]

  out <- merge(panel, county_coefs, by = "fips", all.x = TRUE)
  out[, mg := beta]
  out
}

  
#############################################################################
# Function that returns a county-year data set with the marginal effects (mg)
# implied by each of the 10 best-ranked starting values (one column per start)
#############################################################################

# County-year marginal effects for the best-ranked GFE runs.
#
# Args:
#   G     - number of groups.
#   data  - data.table with columns fips, year, corn_area, ddr.
#   res   - output of group_FE(). Runs are ranked best first: columns V1, V2, ...
#           of res$groups and rows 1, 2, ... of res$beta.
#   nbest - number of best runs to report (default 10). If fewer runs are
#           available, all of them are used and a warning is given.
# Returns the (fips, year, corn_area, ddr) rows of `data`, plus columns
# mg_b1..mg_b<nbest>: the county's slope (marginal effect of ddr) under run k.
# A missing (NA) group slope counts as 0.
group_fips_10b <- function(G, data, res, nbest = 10) {

  panel <- data[, .SD, .SDcols = c("fips", "year", "corn_area", "ddr")]
  dt_tot <- copy(panel)

  # res$groups holds fips plus one column per run.
  n_starts <- ncol(res$groups) - 1
  if (nbest > n_starts) {
    warning("Only ", n_starts,
            " starting values available; using all of them.", call. = FALSE)
    nbest <- n_starts
  }

  for (k in seq_len(nbest)) {

    # County table for run k: fips, group and each group's slope.
    dc <- res$groups[, .SD, .SDcols = c("fips", paste0("V", k))]
    setnames(dc, c("fips", "group"))
    for (g in seq_len(G)) {
      dc[, (paste0("beta_g", g)) := res$beta[k, g]]
    }

    # County slope = its own group's slope.
    dc[, beta := 0]
    for (g in seq_len(G)) {
      slope_g <- dc[[paste0("beta_g", g)]]
      slope_g[is.na(slope_g)] <- 0
      dc[, beta := beta + slope_g * (group == g)]
    }

    # Expand to county-year rows; the marginal effect of ddr is beta.
    dc <- merge(dc, panel, by = "fips", all = TRUE)
    dc[, mg := beta]

    dt_tot <- merge(dt_tot, dc[, .(fips, year, mg)],
                    by = c("fips", "year"), all = TRUE)
    setnames(dt_tot, "mg", paste0("mg_b", k))
  }

  dt_tot
}

#############################################################################
# Function that renumbers the group variable in dc so that group 1 has the
# largest beta, group 2 the next largest, and so on
#############################################################################

# Renumber groups by decreasing slope.
#
# dc must have a `beta` column. Its existing `group` column is dropped and
# rebuilt so that group 1 holds the largest distinct beta, group 2 the
# next, and so on.
# A copy is returned; dc itself is not modified.
# Rows with a missing beta get group NA. Groups whose betas are exactly
# equal end up with the same number.
data_groups_order <- function(dc) {

  dc_now <- copy(dc)
  dc_now[, group := NULL]

  # Distinct slopes, largest first (sort() drops NA/NaN).
  beta_s <- sort(unique(dc_now$beta), decreasing = TRUE)

  # A row's new group is the position of its beta in beta_s. One vectorised
  # lookup replaces a subset per distinct value.
  dc_now[, group := match(beta, beta_s)]

  dc_now
}

#############################################################################
# Function that predicts yield from GFE estimates, either on the estimation
# data (in-sample) or on a new data set
#############################################################################

# Predict yield from GFE estimates.
#
# Args:
#   res       - output of group_FE(). Row 1 of res$gamma (best run) supplies
#               the control coefficients; colnames(res$gamma) name the
#               control columns.
#   dc        - county-year table with columns fips, year, beta (for example
#               the dc built by gfe_m1_wrap()).
#   data      - estimation data with fips, year, yield, ddr and the control
#               columns.
#   data_pred - data.frame/data.table to predict on. Anything else (the
#               default NA, or NULL) means in-sample prediction on `data`.
#   G         - number of groups (not used here; kept for the interface).
#
# The county intercept alpha is the county mean of yield - (beta*ddr + Z gamma)
# in `data`. Out of sample, each county's mean beta and alpha are
# inner-joined to data_pred by fips.
# Returns list(dt, dtp). dt is the estimation table and dtp the prediction
# table. Each has columns fips, year, alpha, beta, yield, yhat.
predict_gfe <- function(res, dc, data, data_pred = NA, G) {

  dt <- copy(data)

  gamma <- res$gamma[1, ]
  ngamma <- colnames(res$gamma)

  # Fit without intercepts, then recover each county's intercept.
  dt <- merge(dt, dc[, .(fips, year, beta)], by = c("fips", "year"),
              all = TRUE)
  dt[, xb := beta * ddr]
  Z <- as.matrix(dt[, .SD, .SDcols = ngamma])
  dt[, yhat := xb + as.vector(Z %*% gamma)]
  dt[, alpha := mean(yield - yhat), by = "fips"]

  if (is.data.frame(data_pred)) {
    # New data: attach each county's beta and alpha.
    ds <- dt[, lapply(.SD, mean), by = "fips", .SDcols = c("beta", "alpha")]
    dtp <- merge(as.data.table(data_pred), ds, by = "fips")
  } else {
    # In-sample: dt already carries beta and alpha on every row. `data`
    # has neither column, so it cannot be used directly here.
    dtp <- copy(dt)
  }

  Zp <- as.matrix(dtp[, .SD, .SDcols = ngamma])
  dtp[, yhat := as.vector(Zp %*% gamma) + beta * ddr + alpha]

  keep <- c("fips", "year", "alpha", "beta", "yield", "yhat")
  dt <- dt[, .SD, .SDcols = keep]
  dtp <- dtp[, .SD, .SDcols = keep]

  list(dt = dt, dtp = dtp)
}